# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""x3d head"""
import mindspore.nn as nn
import mindspore.ops as ops
from mindvision.video.utils.x3d_operators import AvgPool3D
from mindvision.classification.models.head import DenseHead
from mindvision.engine.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.HEAD)
class X3DHead(nn.Cell):
    """
    X3D head.
    This layer performs a fully-connected projection during training, when the
    input size is 1x1x1. It performs a convolutional projection during testing
    when the input size is larger than 1x1x1. If the inputs are from multiple
    different pathways, the inputs will be concatenated after pooling.

    Args:
        dim_in (int): the channel dimension C of the input.
        dim_inner (int): the channel dimension of the intermediate 1x1x1
            convolution projection.
        dim_out (int): the channel dimension of the features fed to the
            classifier.
        num_classes (int): the channel dimensions of the output.
        pool_size (list): kernel size [T, H, W] for the spatiotemporal
            average pooling over the TxHxW dimensions.
        dropout_rate (float): dropout rate. If equal to 0.0, perform no
            dropout. Default: 0.0.
        act_func (str): activation function to use for evaluation. 'softmax':
            applies softmax on the output. 'sigmoid': applies sigmoid on the
            output. Default: "softmax".
        eps (float): epsilon for batch norm. Default: 1e-5.
        bn_mmt (float): momentum for batch norm. Default: 0.1.
            NOTE(review): MindSpore BatchNorm `momentum` is the keep factor
            of the running statistics (i.e. 1 - PyTorch momentum); confirm
            that 0.1 is the intended value for `norm_module`.
        norm_module (nn.Cell): class of the normalization layer. Default:
            nn.BatchNorm3d.
        bn_lin5_on (bool): if True, perform normalization on the features
            before the classifier. Default: False.

    Returns:
        Tensor

    Raises:
        NotImplementedError: if `act_func` is neither "softmax" nor "sigmoid".

    Examples:
        >>> head = X3DHead(dim_in=192, dim_inner=432, dim_out=2048,
        >>>         num_classes=400, pool_size=[16, 7, 7], dropout_rate=0.5)

    """

    def __init__(self,
                 dim_in,
                 dim_inner,
                 dim_out,
                 num_classes,
                 pool_size,
                 dropout_rate=0.0,
                 act_func="softmax",
                 eps=1e-5,
                 bn_mmt=0.1,
                 norm_module=nn.BatchNorm3d,
                 bn_lin5_on=False,
                 ):
        super(X3DHead, self).__init__()
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes
        self.act_func = act_func
        self.eps = eps
        self.bn_mmt = bn_mmt
        self.bn_lin5_on = bn_lin5_on
        self._construct_head(dim_in, dim_inner, dim_out, norm_module)

    def _construct_head(self, dim_in, dim_inner, dim_out, norm_module):
        """Build the projection, pooling, classifier and activation layers."""
        # 1x1x1 convolutional projection: dim_in -> dim_inner.
        self.conv_5 = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=(1, 1, 1),
            stride=(1, 1, 1),
            padding=0,
            has_bias=False,
        )
        self.conv_5_bn = norm_module(
            num_features=dim_inner, eps=self.eps, momentum=self.bn_mmt
        )
        self.conv_5_relu = nn.ReLU()

        # todo: replace AvgPool3d with AdaptiveAvgPool3d
        self.avg_pool = AvgPool3D(tuple(self.pool_size), strides=1)

        # Second 1x1x1 projection: dim_inner -> dim_out.
        self.lin_5 = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=(1, 1, 1),
            padding=0,
            has_bias=False,
        )
        if self.bn_lin5_on:
            self.lin_5_bn = norm_module(
                num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
            )

        self.lin_5_relu = nn.ReLU()

        # Operator instances belong in initialization, not in construct():
        # creating them per forward call is wasteful and non-idiomatic.
        self.transpose = ops.Transpose()

        # Perform FC in a fully convolutional manner. The FC layer will be
        # initialized with a different std comparing to convolutional layers.
        # DenseHead's `keep_prob` is the probability of KEEPING an activation,
        # i.e. the complement of the dropout rate. Passing dropout_rate
        # directly would drop everything when dropout_rate == 0.0, contrary
        # to the documented "0.0 means no dropout" contract.
        self.dense = DenseHead(input_channel=dim_out, num_classes=self.num_classes, has_bias=True,
                               keep_prob=1.0 - self.dropout_rate)

        # Softmax for evaluation and testing, applied over the class axis
        # (last axis after the (N, T, H, W, C) transpose in construct()).
        if self.act_func == "softmax":
            self.act = nn.Softmax(axis=4)
        elif self.act_func == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise NotImplementedError(
                "{} is not supported as an activation"
                "function.".format(self.act_func)
            )

    def construct(self, x):
        """Forward pass: project, pool, classify; returns (N, -1) scores."""
        x = self.conv_5(x)
        x = self.conv_5_bn(x)
        x = self.conv_5_relu(x)
        x = self.avg_pool(x)

        x = self.lin_5(x)
        if self.bn_lin5_on:
            x = self.lin_5_bn(x)
        x = self.lin_5_relu(x)

        # (N, C, T, H, W) -> (N, T, H, W, C).
        x = self.transpose(x, (0, 2, 3, 4, 1))

        x = self.dense(x)

        # Performs fully convolutional inference: activate, then average the
        # predictions over the remaining T, H, W positions.
        if not self.training:
            x = self.act(x)
            x = x.mean([1, 2, 3])

        x = x.view(x.shape[0], -1)
        return x
